import numpy as np
import pandas as pd
from sklearn.cluster import KMeans



def uniform_bin(x, num_bins, alpha=0.01, mini=None, maxi=None):
    '''
    Bin the continuous variable x into num_bins equal-width categories.

    Parameters
    ----------
    x : array-like of numbers to bin.
    num_bins : number of equal-width bins; must be at least 1.
    alpha : margin subtracted from min(x) / added to max(x) when the
        corresponding bound is not supplied; with a positive alpha both
        extremes of x fall strictly inside the binning range.
    mini, maxi : optional explicit lower / upper bound of the binning range.

    Returns
    -------
    numpy.ndarray of integer bin labels, one per element of x. Values inside
    [mini, maxi) are labelled 1, 2, ..., num_bins. Following numpy.digitize,
    values below mini are labelled 0 and values at or above maxi are labelled
    num_bins + 1 (only possible when a bound is passed explicitly or
    alpha <= 0).

    Raises
    ------
    ValueError if num_bins is smaller than 1.

    todo: if mean+3sigma>x.max(), iteratively use mean+2.5sigma ...
    '''
    if num_bins < 1:
        raise ValueError('num_bins must be at least 1')

    # Identity checks: `== None` would broadcast element-wise on array bounds.
    if mini is None:
        mini = np.min(x) - alpha
    if maxi is None:
        maxi = np.max(x) + alpha

    # num_bins bins are delimited by num_bins + 1 edges.
    edges = np.linspace(mini, maxi, num_bins + 1)
    return np.digitize(x, edges)


def quantile_bin(x_values, num_bins, eps=1e-5):
    '''
    Bin a continuous variable by sample quantiles, so that each bin holds
    (approximately) the same number of sample points.

    Parameters
    ----------
    x_values : one-dimensional array-like of numbers to bin.
    num_bins : requested number of quantile bins (passed to pandas.qcut).
    eps : amount added to the last bin edge so that x.max() is included in
        the last bin.

    Returns
    -------
    numpy.ndarray of integer bin labels 1, 2, ..., k with k <= num_bins.
    k is smaller than num_bins when repeated values produce duplicate
    quantile edges, which are dropped (duplicates='drop').

    note: numpy.digitize uses bin[i-1] <= x < bin[i], so the last edge is
    enlarged a little to put x.max() into the last bin.
    ''' #  pd.qcut
    _, edges = pd.qcut(x_values, num_bins, retbins=True, duplicates='drop')

    # Work on a private float copy so the edit below never touches an array
    # owned by pandas.
    edges = np.array(edges, dtype=float)

    top = edges[-1]
    widened = top + eps
    if not widened > top:
        # eps was lost to floating-point rounding (large magnitudes) or is
        # not positive: step to the next representable float instead.
        widened = np.nextafter(top, np.inf)
    edges[-1] = widened

    return np.digitize(x_values, edges)

def quantile_bin_with_min_samples(x_values, num_bins, min_samples=1, eps=1e-5):
    '''
    Bin a continuous variable by sample quantiles and merge bins until every
    bin holds at least min_samples sample points.

    The initial edges come from pandas.qcut (duplicate edges dropped). While
    some bin has fewer than min_samples points, the first such bin is merged
    into a neighbour by removing the edge between them: the only neighbour
    for the first / last bin, otherwise the neighbour with fewer points (the
    left one on ties). Merging stops once a single bin is left.

    Parameters
    ----------
    x_values : one-dimensional array-like of numbers to bin.
    num_bins : requested number of quantile bins (passed to pandas.qcut).
    min_samples : minimum number of points required in each bin.
    eps : amount added to the last bin edge so that x.max() is included in
        the last bin.

    Returns
    -------
    numpy.ndarray of integer bin labels 1, 2, ..., k with k <= num_bins.
    If fewer than min_samples points exist in total, all points share a
    single bin.
    '''
    values = np.asarray(x_values)
    _, edges = pd.qcut(values, num_bins, retbins=True, duplicates='drop')

    # Private float copy: the edges are edited and shrunk below.
    edges = np.array(edges, dtype=float)

    # numpy.digitize uses bin[i-1] <= x < bin[i]; widen the last edge so the
    # maximum lands in the last bin even when eps is lost to rounding.
    top = edges[-1]
    widened = top + eps
    if not widened > top:
        widened = np.nextafter(top, np.inf)
    edges[-1] = widened

    while len(edges) > 2:
        n_bins = len(edges) - 1
        labels = np.digitize(values, edges)
        # counts[i] is the population of bin i + 1; empty bins count as 0.
        counts = np.bincount(labels, minlength=n_bins + 1)[1:n_bins + 1]
        sparse = np.flatnonzero(counts < min_samples)
        if sparse.size == 0:
            break

        idx = int(sparse[0])
        # Bin idx spans edges[idx]..edges[idx + 1]; removing one of them
        # merges the bin with its left or right neighbour respectively.
        if idx == 0:
            drop = 1
        elif idx == n_bins - 1:
            drop = idx
        elif counts[idx - 1] <= counts[idx + 1]:
            drop = idx
        else:
            drop = idx + 1
        edges = np.delete(edges, drop)

    return np.digitize(values, edges)




def K_means(x_values, num_bins):
    '''
    Bin a continuous variable by one-dimensional k-means clustering.

    The values are arranged as a single feature column and clustered with
    KMeans(n_clusters=num_bins, random_state=0); each point is labelled with
    its cluster index plus one.

    Parameters
    ----------
    x_values : array-like of numbers; flattened into one column.
    num_bins : number of clusters to fit.

    Returns
    -------
    Array of integer labels in 1, 2, ..., num_bins. The label numbers are the
    cluster indices assigned by KMeans; they are not sorted by value here.
    '''
    # np.asarray lets lists and other array-likes through to reshape.
    column = np.asarray(x_values).reshape(-1, 1)
    model = KMeans(n_clusters=num_bins, random_state=0)
    return model.fit_predict(column) + 1
